[MLIR][mesh] Mesh fixes #124724
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
ping @sogartar @mfrancio @yaochengji Could you please have a look?
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes: A collection of fixes to the mesh dialect:
- allow constants in sharding propagation/spmdization
- fixes to tensor replication (e.g. 0d tensors)
- improved canonicalization
- sharding propagation incorrectly generated too many ShardOps

New operation `mesh.GetShardingOp` enables exchanging sharding information (for example, on function boundaries).
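The new op takes a ranked tensor and returns its sharding; a minimal sketch of the textual form, following the assemblyFormat declared in the patch (the SSA names and the tensor type here are hypothetical):

    %s = mesh.get_sharding %t : tensor<4x8xf32> -> !mesh.sharding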
@yaochengji @AntonLydike

Patch is 48.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124724.diff

17 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..5addffbe571bee1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 75cb096130ca6e4..7de7842baf98abf 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
SmallVector<Value> dynamic_sharded_dims_offsets;
public:
- MeshSharding() = default;
+ MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ class MeshSharding {
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
- ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+ ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert shard op if there is not one that already has the same sharding.
+// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder);
+// Return the target ShardOP (new or existing).
+ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 6039e61a93fadc5..031e6f63bcb42cc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
+ OpBuilder<(ins "llvm::StringRef":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
+ )>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
+def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+ let summary = "Get the sharding of the given tensor.";
+ let description = [{
+ This operation returns the sharding of the given tensor as a MeshSharding.
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$source
+ );
+ let results = (outs
+ Mesh_Sharding:$result
+ );
+ let assemblyFormat = [{
+ $source attr-dict `:` type($source) `->` type($result)
+ }];
+}
+
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index b4d25cef05a7b96..14aad7f9f6783d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
- : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+ : shardingArray(std::move(shardingArray)), mesh(mesh) {
+ assert(this->mesh);
+ }
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0da82825c82878a..33bc89279c08c32 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35befe7de2..f96bda603baa63d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
+ ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
+ MLIRMeshDialect
MLIRPass
+ MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f31db4906775687
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+ : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+ ConstantOp> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ auto ndims = 0;
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ndims = type.getRank();
+ }
+ return SmallVector<utils::IteratorType>(ndims,
+ utils::IteratorType::parallel);
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+ type.getRank(), op->getContext())});
+ }
+ return {};
+ }
+
+ // Indicate failure if no result sharding exists.
+ // Otherwise mirror result sharding if it is a tensor constant.
+ // Otherwise return replication option.
+ FailureOr<ShardingOption>
+ getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings) const {
+ if (!resultShardings[0]) {
+ return failure();
+ }
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+ for (auto [i, axes] :
+ llvm::enumerate(resultShardings[0].getSplitAxes())) {
+ axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+ }
+ return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+ }
+ return ShardingOption({}, resultShardings[0].getMeshAttr());
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ auto cOp = cast<ConstantOp>(op);
+ auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
+ if (value) {
+ if (!value.isSplat() || !resultShardings[0]) {
+ // Currently non-splat constants are not supported.
+ return failure();
+ }
+ auto sharding = resultShardings[0];
+ auto newType = cast<RankedTensorType>(shardType(
+ cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+ sharding));
+ auto newValue = value.resizeSplat(newType);
+ auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+ spmdizationMap.map(op->getResult(0), newOp.getResult());
+ spmdizationMap.map(op, newOp.getOperation());
+ } else {
+ // `clone` will populate the mapping of old to new results.
+ (void)builder.clone(*op, spmdizationMap);
+ }
+ return success();
+ }
+};
+} // namespace
+
+void mlir::arith::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
+
+ registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+ ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 33460ff25e9e45d..c789fc527e3f680 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
+ // 0d tensors cannot be sharded and must get replicated
+ if (inShape.empty()) {
+ assert(outShape.empty());
+ return;
+ }
+
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
@@ -269,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder) {
+ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
@@ -279,14 +286,17 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
- // No need for anything the correct sharding is already set.
- return;
+ // No need for anything if the correct sharding is already set.
+ return newShardOp ? newShardOp : shardOp;
}
- auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
- auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ false);
+ if (!newShardOp) {
+ auto shardingOp =
+ builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+ newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+ /*annotate_for_users*/ false);
+ }
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -294,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
});
if (!shardOp || shardOp.getAnnotateForUsers()) {
- return;
+ return newShardOp;
}
- auto newShardOp2 =
- builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
- /*annotate_for_users*/ true);
+ auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+ newShardOp.getSharding(),
+ /*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+ return newShardOp;
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
+ ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
- maybeInsertTargetShardingAnnotation(sharding, use, builder);
+ newShardOp =
+ maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}
@@ -316,9 +329,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
- Operation *operandOp = operand.getOwner();
Operation *operandSrcOp = operandValue.getDefiningOp();
bool isBlockArg = !operandSrcOp;
+ {
+ auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
+ assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
+ }
+ if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
+ operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+ return;
+ }
+
+ Operation *operandOp = operand.getOwner();
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -432,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
- ArrayRef<int64_t> static_halo_sizes,
- ArrayRef<int64_t> static_sharded_dims_offsets) {
+ ArrayRef<int64_t> static_halos,
+ ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(),
- static_sharded_dims_offsets),
- {});
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -453,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
{}, {}, {}, {});
}
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+ llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+ ArrayRef<int64_t> static_halos,
+ ArrayRef<int64_t> static_offsets) {
+ return build(
+ b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+ MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+ ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+}
+
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
@@ -579,9 +611,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
namespace {
// Sharding annotations "halo sizes" and "sharded dims offsets"
// are a mix of attributes and dynamic values. This canonicalization moves
-// constant values to the respective attribute lists and so minimizes the number
+// constant values to the respective attribute lists, minimizing the number
// of values.
-class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+// It also removes sharded_dims_sizes and halos if they are effectively "empty".
+class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
using OpRewritePattern<ShardingOp>::OpRewritePattern;
@@ -593,14 +626,41 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
op.getDynamicShardedDimsOffsets(), b);
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
- failed(foldDynamicIndexList(mixedOffs, /*onlyNo...
[truncated]
Comments on the first patch.
Thanks @rengolin for your thorough review and comments. I made the modifications as you suggested.
Thanks! So far so good. Just reviewed the final commits and left some comments. Should be good with those addressed.
Thank you Frank, this looks really good!
Attaching ShardingInterface to arith::ConstantOp and handling GetShardingOp in ShardingPropagation
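For downstream tools, the new external model is registered like the other arith interface models; a minimal sketch based on the registration call this PR adds to InitAllDialects.h (the wrapper function registerMyDialects is hypothetical):

    #include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
    #include "mlir/IR/DialectRegistry.h"

    // Attach the ConstantShardingInterface external model to arith.constant
    // so sharding propagation and spmdization can handle constants.
    void registerMyDialects(mlir::DialectRegistry &registry) {
      mlir::arith::registerShardingInterfaceExternalModels(registry);
    }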
LLVM Buildbot has detected a new failure on a builder. Full details, including the relevant piece of the build log, are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/22645
A collection of fixes to the mesh dialect:
- allow constants in sharding propagation/spmdization
- fixes to tensor replication (e.g. 0d tensors)
- improved canonicalization
- sharding propagation incorrectly generated too many ShardOps

New operation `mesh.GetShardingOp` enables exchanging sharding information (for example, on function boundaries).

@yaochengji @AntonLydike
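The patch also adds a convenience builder for mesh.sharding that takes the mesh by symbol name. A minimal sketch of how it might be called (the OpBuilder b, location loc, and the mesh name "mesh0" are hypothetical; the signature follows the OpBuilder added to MeshOps.td above):

    // Shard tensor dim 0 along mesh axis 0 with a static lower/upper halo
    // of one element each (assuming MeshAxesAttr is the dialect's dense
    // i16 array attribute and halos are given as lower/upper pairs).
    llvm::SmallVector<mlir::mesh::MeshAxesAttr> splitAxes = {
        mlir::mesh::MeshAxesAttr::get(b.getContext(), {0})};
    auto sharding = b.create<mlir::mesh::ShardingOp>(
        loc, "mesh0", splitAxes,
        /*static_halo_sizes=*/llvm::ArrayRef<int64_t>{1, 1});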